# import torch
import numpy as np
from sklearn.model_selection import train_test_split
from typing import NamedTuple, Optional

class PVTrainDataSet(NamedTuple):
    """Immutable bundle of the arrays that make up one training set.

    ``split_train_data`` in this module takes the sample count from
    ``treatment.shape[0]`` and indexes every non-``None`` field with the same
    row indices, so all populated fields are expected to share their
    first-axis length.

    ``backdoor`` defaults to ``None`` so data sets without that variable can
    be built from the four mandatory arrays alone.
    """

    # NOTE(review): the field meanings below are inferred from the names
    # (proximal causal learning terminology) -- confirm against the data
    # generators that build these tuples.
    treatment: np.ndarray        # presumably the treatment / action variable
    treatment_proxy: np.ndarray  # presumably the treatment-side proxy
    outcome_proxy: np.ndarray    # presumably the outcome-side proxy
    outcome: np.ndarray          # presumably the observed outcome
    backdoor: Optional[np.ndarray] = None  # optional extra variable; may be absent

class PVTestDataSet(NamedTuple):
    """Immutable bundle of the arrays that make up one test set.

    It has the same five fields, in the same order, as ``PVTrainDataSet``.
    ``backdoor`` defaults to ``None`` so data sets without that variable can
    be built from the four mandatory arrays alone.
    """

    # NOTE(review): the field meanings below are inferred from the names
    # (proximal causal learning terminology) -- confirm against the data
    # generators that build these tuples.
    treatment: np.ndarray        # presumably the treatment / action variable
    treatment_proxy: np.ndarray  # presumably the treatment-side proxy
    outcome_proxy: np.ndarray    # presumably the outcome-side proxy
    outcome: np.ndarray          # presumably the outcome used for evaluation
    backdoor: Optional[np.ndarray] = None  # optional extra variable; may be absent


def split_train_data(train_data: PVTrainDataSet, split_ratio=0.5, random_state=42):
    """Split a training set into two disjoint subsets of rows.

    Row indices ``0 .. n-1`` (``n`` being the first-axis length of the first
    field) are partitioned with ``sklearn.model_selection.train_test_split``,
    and every non-``None`` field is indexed with the resulting index arrays.

    Parameters
    ----------
    train_data : PVTrainDataSet
        Data to split. Fields that are ``None`` stay ``None`` in both halves.
    split_ratio : float or int, default 0.5
        Forwarded as ``train_size`` to ``train_test_split``; it controls the
        size of the first subset. A negative value disables splitting.
    random_state : int, RandomState instance or None, default 42
        Forwarded to ``train_test_split`` to control the shuffle. The default
        reproduces the seed that was previously hard-coded.

    Returns
    -------
    (PVTrainDataSet, PVTrainDataSet)
        The first and second subset. When ``split_ratio`` is negative the
        same ``train_data`` object is returned twice (no copy is made).

    Raises
    ------
    Whatever ``train_test_split`` raises for a ``train_size`` it rejects
    (see the scikit-learn documentation for the accepted range).
    """
    # A negative ratio is the "do not split" sentinel: both outputs are the
    # full, unmodified data set.
    if split_ratio < 0.0:
        return train_data, train_data

    # Split index positions rather than the arrays themselves so that every
    # field is partitioned with exactly the same rows.
    n_data = train_data[0].shape[0]
    idx_train_1st, idx_train_2nd = train_test_split(
        np.arange(n_data), train_size=split_ratio, random_state=random_state
    )

    def get_data(data, idx):
        # Optional fields (e.g. ``backdoor``) may be None; pass that through.
        return data[idx] if data is not None else None

    train_1st_data = PVTrainDataSet(*[get_data(data, idx_train_1st) for data in train_data])
    train_2nd_data = PVTrainDataSet(*[get_data(data, idx_train_2nd) for data in train_data])

    return train_1st_data, train_2nd_data